/******************************************************************************
 * Copyright 2022 The AIR Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/

#pragma once

#include <memory>
#include <string>

#include <openssl/err.h>

namespace cw {
KeyPair::KeyPair(const std::string &filename, const std::string &password)
    : KeyPair(BIO_new_file(filename.c_str(), "r"), password) {}

KeyPair::KeyPair(const std::string &filename, tag_public_t)
    : KeyPair(BIO_new_file(filename.c_str(), "r"), tag_public_t{}) {}

KeyPair::KeyPair(const void *buffer, int buflen, const std::string &password)
    : KeyPair(BIO_new_mem_buf(buffer, buflen), password) {}

KeyPair::KeyPair(const void *buffer, int buflen, tag_public_t)
    : KeyPair(BIO_new_mem_buf(buffer, buflen), tag_public_t{}) {}

/// Generates a fresh RSA private key.
///
/// The requested `tag.bit_length` is first raised to kRSAMinBitLength and then
/// lowered to kRSAMaxBitLength if needed; the public exponent is
/// kRSAPublicExponent. The generated key is validated with Check() before it
/// is adopted. On any failure the reason is logged to stderr, the OpenSSL
/// error queue is printed and cleared, and the object is left without a key
/// (pkey_ and is_public_ are not assigned).
KeyPair::KeyPair(tag_rsa_t tag) {
  // Logs `what`, then dumps and drains the OpenSSL error queue.
  const auto report = [](const char *what) {
    std::cerr << what << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
  };

  std::shared_ptr<EVP_PKEY> key(EVP_PKEY_new(), EVP_PKEY_free);
  std::shared_ptr<RSA> rsa_key(RSA_new(), RSA_free);
  std::shared_ptr<BIGNUM> exponent(BN_new(), BN_free);
  if (!(key && rsa_key && exponent)) {
    report("[CW] Failed to create ptrs");
    return;
  }

  // Clamp the modulus size: lower bound first, then upper bound.
  auto bits = tag.bit_length;
  if (bits < kRSAMinBitLength) {
    bits = kRSAMinBitLength;
  }
  if (bits > kRSAMaxBitLength) {
    bits = kRSAMaxBitLength;
  }

  // Each step runs only if the previous one succeeded.
  const bool generated =
      BN_set_word(exponent.get(), kRSAPublicExponent) &&
      RSA_generate_key_ex(rsa_key.get(), bits, exponent.get(), nullptr) &&
      EVP_PKEY_set1_RSA(key.get(), rsa_key.get());
  if (!generated) {
    report("[CW] Failed to generate RSA key");
    return;
  }

  if (!Check(key.get())) {
    report("[CW] Failed to check RSA key");
    return;
  }

  pkey_ = key;
  is_public_ = false;
}

/// Generates a fresh EC private key on the curve `tag.curve` (passed to
/// EC_KEY_new_by_curve_name, i.e. an OpenSSL curve NID).
///
/// The generated key is validated with Check() before it is adopted. On any
/// failure the reason is logged to stderr, the OpenSSL error queue is printed
/// and cleared, and the object is left without a key (pkey_ and is_public_
/// are not assigned).
KeyPair::KeyPair(tag_ec_t tag) {
  // Logs `what`, then dumps and drains the OpenSSL error queue.
  const auto report = [](const char *what) {
    std::cerr << what << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
  };

  std::shared_ptr<EVP_PKEY> key(EVP_PKEY_new(), EVP_PKEY_free);
  std::shared_ptr<EC_KEY> curve_key(EC_KEY_new_by_curve_name(tag.curve),
                                    EC_KEY_free);
  if (!(key && curve_key)) {
    report("[CW] Failed to create ptrs");
    return;
  }

  // Serialize the curve as a named-curve reference rather than as explicit
  // domain parameters.
  EC_KEY_set_asn1_flag(curve_key.get(), OPENSSL_EC_NAMED_CURVE);

  // Attach to the EVP_PKEY only if generation succeeded.
  const bool generated = EC_KEY_generate_key(curve_key.get()) &&
                         EVP_PKEY_set1_EC_KEY(key.get(), curve_key.get());
  if (!generated) {
    report("[CW] Failed to generate EC key");
    return;
  }

  if (!Check(key.get())) {
    report("[CW] Failed to check EC key");
    return;
  }

  pkey_ = key;
  is_public_ = false;
}

/// Wraps an existing RSA key in a new EVP_PKEY.
///
/// @param rsa  RSA key to wrap. EVP_PKEY_set1_RSA takes its own reference,
///             so the caller keeps (and remains responsible for) its own.
///
/// The KeyPair is public-only when `rsa` has no private exponent
/// (RSA_get0_d returns NULL). This constructor does not run Check(). On any
/// failure the reason is logged to stderr and the object is left without a
/// key (pkey_ and is_public_ are not assigned).
KeyPair::KeyPair(RSA *rsa) {
  if (!rsa) {
    std::cerr << "[CW] rsa is NULL" << std::endl;
    return;
  }
  auto pkey = std::shared_ptr<EVP_PKEY>(EVP_PKEY_new(), EVP_PKEY_free);
  if (!pkey) {
    std::cerr << "[CW] Failed to create pkey" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return;
  }
  if (!EVP_PKEY_set1_RSA(pkey.get(), rsa)) {
    std::cerr << "[CW] Failed to set RSA key" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    // Must bail out here: storing the still-empty EVP_PKEY would make the
    // object look valid (non-null pkey_) while it holds no key material.
    return;
  }
  pkey_ = pkey;
  is_public_ = !RSA_get0_d(rsa);
}

/// Wraps an existing EC key in a new EVP_PKEY.
///
/// @param ec_key  EC key to wrap. EVP_PKEY_set1_EC_KEY takes its own
///                reference, so the caller keeps (and remains responsible
///                for) its own.
///
/// The KeyPair is public-only when `ec_key` has no private scalar
/// (EC_KEY_get0_private_key returns NULL). This constructor does not run
/// Check(). On any failure the reason is logged to stderr and the object is
/// left without a key (pkey_ and is_public_ are not assigned).
KeyPair::KeyPair(EC_KEY *ec_key) {
  if (!ec_key) {
    std::cerr << "[CW] ec_key is NULL" << std::endl;
    return;
  }
  auto pkey = std::shared_ptr<EVP_PKEY>(EVP_PKEY_new(), EVP_PKEY_free);
  if (!pkey) {
    std::cerr << "[CW] Failed to create pkey" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return;
  }
  if (!EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key)) {
    std::cerr << "[CW] Failed to set EC key" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    // Must bail out here: storing the still-empty EVP_PKEY would make the
    // object look valid (non-null pkey_) while it holds no key material.
    return;
  }
  pkey_ = pkey;
  is_public_ = !EC_KEY_get0_private_key(ec_key);
}

/// Loads a PEM-encoded private key from `bio` and takes ownership of the BIO
/// (it is released with BIO_free_all on every exit path).
///
/// @param bio       Source BIO; a null BIO is reported and leaves the object
///                  without a key.
/// @param password  Passed as the user-data argument to
///                  PEM_read_bio_PrivateKey with no callback, which OpenSSL
///                  treats as the NUL-terminated passphrase.
///
/// The loaded key must pass Check(); otherwise the error is logged, the
/// OpenSSL error queue is printed and cleared, and no key is stored.
KeyPair::KeyPair(BIO *bio, const std::string &password) {
  if (bio == nullptr) {
    std::cerr << "[CW] BIO is NULL" << std::endl;
    return;
  }
  const std::shared_ptr<BIO> source(bio, BIO_free_all);

  EVP_PKEY *parsed = PEM_read_bio_PrivateKey(
      source.get(), nullptr, nullptr, const_cast<char *>(password.c_str()));
  const std::shared_ptr<EVP_PKEY> key(parsed, EVP_PKEY_free);

  // Check() also rejects a null key, so a parse failure lands here too.
  const bool valid = Check(key.get());
  if (!valid) {
    std::cerr << "[CW] Failed to check private key" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return;
  }

  pkey_ = key;
  is_public_ = false;
}

/// Loads a PEM-encoded public key from `bio` (via PEM_read_bio_PUBKEY) and
/// takes ownership of the BIO (it is released with BIO_free_all on every
/// exit path).
///
/// A null BIO is reported and leaves the object without a key. The loaded key
/// must pass the public-only Check(); otherwise the error is logged, the
/// OpenSSL error queue is printed and cleared, and no key is stored.
KeyPair::KeyPair(BIO *bio, tag_public_t) {
  if (bio == nullptr) {
    std::cerr << "[CW] BIO is NULL" << std::endl;
    return;
  }
  const std::shared_ptr<BIO> source(bio, BIO_free_all);

  EVP_PKEY *parsed =
      PEM_read_bio_PUBKEY(source.get(), nullptr, nullptr, nullptr);
  const std::shared_ptr<EVP_PKEY> key(parsed, EVP_PKEY_free);

  // Check() also rejects a null key, so a parse failure lands here too.
  const bool valid = Check(key.get(), true);
  if (!valid) {
    std::cerr << "[CW] Failed to check public key" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return;
  }

  pkey_ = key;
  is_public_ = true;
}

/// True when this KeyPair holds only public key material.
bool KeyPair::IsPublic() const {
  return is_public_;
}

/// Borrowed pointer to the underlying EVP_PKEY; null when no key is held.
const EVP_PKEY *KeyPair::NativeHandle() const {
  return pkey_.get();
}

/// EVP_PKEY_id of the held key, or 0 when no key is held.
int KeyPair::Id() const {
  return pkey_ ? EVP_PKEY_id(pkey_.get()) : 0;
}

/// EVP_PKEY_base_id of the held key, or 0 when no key is held.
int KeyPair::BaseId() const {
  return pkey_ ? EVP_PKEY_base_id(pkey_.get()) : 0;
}

std::string KeyPair::PemString(const std::string &password) const {
  if (!pkey_) {
    std::cerr << "[CW] Invalid KeyPair" << std::endl;
    return "";
  }
  auto bio = std::shared_ptr<BIO>(BIO_new(BIO_s_mem()), BIO_free_all);
  if (!bio) {
    std::cerr << "[CW] Failed to create mem bio" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return "";
  }
  if (!password.empty()) {
    if (!PEM_write_bio_PrivateKey(bio.get(), pkey_.get(), EVP_aes_256_cbc(),
                                  reinterpret_cast<unsigned char *>(
                                      const_cast<char *>(password.c_str())),
                                  static_cast<int>(password.length()), nullptr,
                                  nullptr)) {
      std::cerr << "[CW] Failed to write key" << std::endl;
      ERR_print_errors_fp(stderr);
      ERR_clear_error();
      return "";
    }
  } else {
    if (!PEM_write_bio_PrivateKey(bio.get(), pkey_.get(), nullptr, nullptr, 0,
                                  nullptr, nullptr)) {
      std::cerr << "[CW] Failed to write key" << std::endl;
      ERR_print_errors_fp(stderr);
      ERR_clear_error();
      return "";
    }
  }
  char *data = nullptr;
  auto data_size = BIO_get_mem_data(bio.get(), &data);
  return {data, static_cast<size_t>(data_size)};
}

/// Serializes the public components of the key as a PEM public key
/// (PEM_write_bio_PUBKEY).
///
/// @return The PEM text, or "" (after logging to stderr) when no key is held,
///         the memory BIO cannot be created, or writing fails.
std::string KeyPair::PemString(tag_public_t) const {
  if (!pkey_) {
    std::cerr << "[CW] Invalid KeyPair" << std::endl;
    return "";
  }
  auto bio = std::shared_ptr<BIO>(BIO_new(BIO_s_mem()), BIO_free_all);
  // Same guard as the private-key overload: report an allocation failure
  // as such instead of handing a null BIO to the PEM writer.
  if (!bio) {
    std::cerr << "[CW] Failed to create mem bio" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return "";
  }
  if (!PEM_write_bio_PUBKEY(bio.get(), pkey_.get())) {
    std::cerr << "[CW] Failed to write key" << std::endl;
    ERR_print_errors_fp(stderr);
    ERR_clear_error();
    return "";
  }
  char *data = nullptr;
  auto data_size = BIO_get_mem_data(bio.get(), &data);
  return {data, static_cast<size_t>(data_size)};
}

/// Validates `pkey` with OpenSSL's key-consistency checks.
///
/// @param pkey       Key to validate; nullptr is rejected.
/// @param is_public  When true, only the public components are checked
///                   (EVP_PKEY_public_check); otherwise the full key pair is
///                   checked (EVP_PKEY_check).
/// @return true only when OpenSSL reports success (return value 1); false for
///         a null key, a failed context allocation, or any other result.
bool KeyPair::Check(EVP_PKEY *pkey, bool is_public) {
  if (!pkey) {
    return false;
  }
  auto ctx = std::shared_ptr<EVP_PKEY_CTX>(EVP_PKEY_CTX_new(pkey, nullptr),
                                           EVP_PKEY_CTX_free);
  if (!ctx) {
    return false;
  }
  const int rc = is_public ? EVP_PKEY_public_check(ctx.get())
                           : EVP_PKEY_check(ctx.get());
  // Compare against 1 explicitly: these APIs return -2 when the check is not
  // supported for the key's algorithm, and an implicit int->bool conversion
  // would turn such negative codes into "valid".
  return rc == 1;
}

}  // namespace cw
